
Unblock Llama2 ONNX export w/ sdpa by falling back to manual impl #28823

Closed

Conversation

@BowenBao (Contributor) commented Feb 1, 2024

What does this PR do?

Unblocks Llama2 ONNX export with SDPA by falling back to the manual attention implementation when no attention_mask is provided during tracing. Without this change, torch.jit.trace fails with:

ValueError: Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument attn_implementation="eager" or pass an attention_mask input when tracing the model.

Fixes #28610
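
For context, here is a minimal sketch of the two workarounds the error message itself suggests. The checkpoint name and inputs are illustrative assumptions, not part of this PR:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any Llama2 checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello world", return_tensors="pt")

# Workaround 1: load with eager attention so the SDPA tracing check never fires.
eager_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="eager", torchscript=True
).eval()
traced = torch.jit.trace(eager_model, (inputs["input_ids"],))

# Workaround 2: keep SDPA but pass an explicit attention_mask while tracing.
sdpa_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa", torchscript=True
).eval()
traced = torch.jit.trace(sdpa_model, (inputs["input_ids"], inputs["attention_mask"]))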

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fxmarty

@thiagocrepaldi left a comment:

LGTM. Maybe add a unit test for the torch.jit.trace case?
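
A hypothetical sketch of the kind of unit test the reviewer suggests (the tiny test checkpoint and scaffolding are assumptions, not the PR's actual test):

import torch
from transformers import AutoModelForCausalLM

def test_jit_trace_sdpa_without_attention_mask():
    # assumed tiny Llama test checkpoint, chosen to keep the test fast
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM",
        attn_implementation="sdpa",
        torchscript=True,
    ).eval()
    input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
    # with the fallback from this PR, tracing without an attention_mask
    # should no longer raise the SDPA ValueError
    traced = torch.jit.trace(model, (input_ids,))
    assert traced is not None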

@ArthurZucker (Collaborator) left a comment:

Thanks for this! It's going to be a bit hard to merge this as is. Would you mind checking whether #27931 fixes the issue? It should be merged first and should simplify all of this logic.

@BowenBao (Contributor, Author) commented Feb 3, 2024

Hi @ArthurZucker, I have validated that the issue is fixed by your PR, thanks! Do you have an ETA for when it will be merged? Our workstreams have been blocked by this issue for a while, and we need to resolve this export problem ASAP.

@ArthurZucker (Collaborator) commented Feb 5, 2024

This week 😉 Waiting for @gante's green light, and then I will merge #27931 (that wasn't clear earlier).

@@ -673,12 +673,22 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        _jit_tracing = torch.jit.is_tracing()
A contributor commented on the added line:
This means that we call torch.jit.is_tracing as many times as there are layers.
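
A minimal toy sketch of the hoisting this comment implies: query torch.jit.is_tracing() once in the top-level forward and thread the result down, rather than calling it inside every layer. Class and argument names here are illustrative, not the actual transformers code:

import torch
from torch import nn

class ToyDecoderLayer(nn.Module):
    def forward(self, hidden_states, use_manual_attn: bool = False):
        # a trace-safe manual attention path would be selected here when
        # use_manual_attn is True; the attention math is elided for brevity
        return hidden_states

class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(ToyDecoderLayer() for _ in range(num_layers))

    def forward(self, hidden_states):
        jit_tracing = torch.jit.is_tracing()  # queried once, not once per layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, use_manual_attn=jit_tracing)
        return hidden_states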

@fxmarty (Contributor) commented Feb 5, 2024

I don't understand why this change is necessary. The error that is normally raised

ValueError: Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument attn_implementation="eager" or pass an attention_mask input when tracing the model.

explicitly gives a solution.

@thiagocrepaldi commented:

@ArthurZucker @BowenBao I believe we can close this PR now that #27931 has been merged.

Successfully merging this pull request may close these issues:

  • ONNX export failure for models invoking SDPA attention (#28610)